{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from dotenv import load_dotenv\n",
    "import os\n",
    "from pathlib import Path\n",
    "import json\n",
    "from openai import OpenAI\n",
    "from tqdm import tqdm\n",
    "import re\n",
    "from glob import glob\n",
    "from mistralai import Mistral\n",
    "from mistral_batch import MistralAIBatchProcessor\n",
    "from time import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('survey_prompt.txt', 'r') as f:\n",
    "    survey_prompt = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load_dotenv(dotenv_path = '../APIS/.env')\n",
    "os.environ[\"MISTRAL_API_KEY\"] = os.getenv('MISTRAL_API_KEY')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = \"mistral-large-latest\"\n",
    "\n",
    "client = Mistral(api_key=os.getenv('MISTRAL_API_KEY'))\n",
    "\n",
    "chat_response = client.chat.complete(\n",
    "    model= model,\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": survey_prompt,\n",
    "        },\n",
    "    ]\n",
    ")\n",
    "# print(chat_response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_response.usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [{\"role\": \"user\", \"content\": survey_prompt}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create an instance of the batch processor\n",
    "batch_processor = MistralAIBatchProcessor(\n",
    "    model_name=\"mistral-large-latest\",  # Specify the OpenAI model\n",
    "    max_tokens=1600,            # Max tokens per response\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Create tasks for batch processing\n",
    "tasks = [batch_processor.create_task(id=i, messages=messages) for i in range(9000)]  # Create 5 tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 2: Write the tasks to a file\n",
    "batch_processor.write_task_file(tasks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1000\n",
    "num_files = (len(tasks) + batch_size - 1) // batch_size\n",
    "print(f'Total requests: {len(tasks)}. Batch size: {batch_size}. Separated into {num_files} files.\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iterate over chunks of data\n",
    "for i in range(num_files):\n",
    "    print(f\"Processing batch {i+1}/{num_files}...\", end=\"\\n\\n\")\n",
    "\n",
    "    start_index = i * batch_size\n",
    "    end_index = min(start_index + batch_size, len(tasks))  # Avoid out-of-range slicing\n",
    "\n",
    "    # Slice the list to get the current batch\n",
    "    batch_data = tasks[start_index:end_index]\n",
    "\n",
    "    # Ensure we are not writing an empty batch\n",
    "    if not batch_data:\n",
    "        print(f\"Skipping batch {i+1} as it's empty.\")\n",
    "        continue\n",
    "\n",
    "    # Generate batch ID\n",
    "    batch_id = int(time())\n",
    "\n",
    "    # Write batch tasks\n",
    "    batch_processor.write_batch_file(batch_data, batch_id)\n",
    "    print(f\"Batch {batch_id} written successfully.\", end=\"\\n\\n\")\n",
    "\n",
    "    # Upload batch file\n",
    "    batch_file = batch_processor.upload_batch_file(batch_id)\n",
    "    if not batch_file:\n",
    "        print(f\"Failed to upload file for batch {batch_id}. Skipping this batch.\")\n",
    "        continue  # Instead of breaking, we skip and move to the next batch\n",
    "\n",
    "    # Create batch job\n",
    "    batch_job = batch_processor.create_batch_job(batch_file)\n",
    "    if not batch_job:\n",
    "        print(f\"Failed to create batch job for batch {batch_id}. Skipping this batch.\")\n",
    "        continue\n",
    "\n",
    "    # Monitor batch job status until completion\n",
    "    final_status = batch_processor.check_batch_job_status(batch_job.id, check_interval=2)\n",
    "    \n",
    "\n",
    "print(\"\\nBatch processing completed.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Retrieve the last batch output ID\n",
    "output_files = []\n",
    "\n",
    "# Retrieve the latest batches\n",
    "for idx, batch in enumerate(client.batch.jobs.list(status=\"SUCCESS\")):\n",
    "    if idx == 1:\n",
    "        for batch_file in batch[1]:\n",
    "            print(f\"Batch ID: {batch_file.id}, Status: {batch_file.status}\")\n",
    "        \n",
    "            batch_id, created_at, output_file_id = batch_file.id, batch_file.created_at, batch_file.output_file\n",
    "            output_files.append([batch_id, created_at, output_file_id])\n",
    "\n",
    "# Step 3: Print summary\n",
    "print(f\"Total completed batches retrieved: {len(output_files)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, batch_info in enumerate(output_files):\n",
    "    batch_id, created_at, output_file_id = batch_info\n",
    "    print(f\"Processing completed batch {batch_id} (Created: {created_at}) with Output File ID: {output_file_id}\")\n",
    "\n",
    "    batch_processor.save_batch_output(output_file_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_file_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def download_file(client, file_id, output_path):\n",
    "\n",
    "    if file_id is not None:\n",
    "        print(f\"Downloading file to {output_path}\")\n",
    "        output_file = client.files.download(file_id=file_id)\n",
    "        with open(output_path, \"w\") as f:\n",
    "            for chunk in output_file.stream:\n",
    "                f.write(chunk.decode(\"utf-8\"))\n",
    "        print(f\"Downloaded file to {output_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "download_file(client, file_id='fc1150e4-57ee-4ae8-bd22-658e4bee9575', output_path='mistral_batch_outputs/mistral_batch_output0.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
